# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
import logging
from pathlib import Path
from typing import Optional

from ..datasets.canary_exploit.verify_response import verify_responses

from .benchmark import Benchmark
from .llm import LLM
from .query_llm import query_llm_to_generate_responses


class CanaryExploitBenchmark(Benchmark):
    """
    Canary Exploit benchmark built on top of the Benchmark base class.

    The class queries every LLM under test with the prompts of a dataset, stores the
    collected responses as JSON at `response_path`, and hands the stored responses to
    `verify_responses` when the benchmark is run.
    """

    def __init__(
        self,
        *,
        llms_under_test: list[LLM],
        response_path: Path,
        judge_response_path: Optional[Path] = None,
        stat_path: Optional[Path] = None,
        num_test_cases: int = 0,
        expansion_llm: Optional[LLM] = None,
        judge_llm: Optional[LLM] = None,
        pass_k: int = 1,
    ) -> None:
        """
        Set up the benchmark by delegating to the base-class constructor.

        Note that `expansion_llm` is accepted here but is not passed on to the
        base class.
        """
        # TODO: add class validation logic here (for example, checking that judge_llm is
        # not None). Without such validation this override only forwards its arguments.
        super().__init__(
            llms_under_test=llms_under_test,
            response_path=response_path,
            judge_response_path=judge_response_path,
            stat_path=stat_path,
            num_test_cases=num_test_cases,
            judge_llm=judge_llm,
            pass_k=pass_k,
        )

    @classmethod
    def return_kind(cls) -> list[str]:
        """Return the benchmark kind identifiers handled by this class."""
        return ["canary-exploit"]

    def query_llm_to_generate_responses(
        self, prompt_path: Path, run_llm_in_parallel: bool = False
    ) -> None:
        """
        Query each LLM under test with the prompts at `prompt_path` and write all of the
        gathered responses, as indented JSON, to `self.response_path`.
        """
        # JSON schema handed to the query helper: an object with a required string
        # field "answer". The literal is split across lines purely for readability.
        answer_schema = (
            '{"type": "object", '
            '"properties": {"answer": {"type": "string"}}, '
            '"required": ["answer"]}'
        )

        collected = []
        for model_under_test in self.llms_under_test:
            logging.info(f"Querying {model_under_test.model}")
            llm_roles = {"initial_query_llm": model_under_test}
            # Inside a method body this name resolves to the module-level helper,
            # not to this method.
            model_responses = query_llm_to_generate_responses(
                llm_roles,
                prompt_path,
                run_llm_in_parallel,
                num_test_cases=self.num_test_cases,
                pass_k=self.pass_k,
                guided_decode_json_schema=answer_schema,
            )
            collected.extend(model_responses)

        serialized = json.dumps(collected, indent=4)
        self.response_path.write_text(serialized)

    async def run(
        self, num_test_cases: int = 0, run_llm_in_parallel: bool = True
    ) -> None:
        """
        Run the benchmark by passing the response, judge-response and stat paths to
        `verify_responses`.

        `num_test_cases` and `run_llm_in_parallel` are accepted but not used by this
        implementation.
        """
        verify_responses(
            self.response_path,
            self.judge_response_path,
            self.stat_path,
        )
